{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Expose only GPU 0 to this kernel, with devices numbered by PCI bus ID.\n",
    "os.environ.update({\n",
    "    \"CUDA_DEVICE_ORDER\": \"PCI_BUS_ID\",\n",
    "    \"CUDA_VISIBLE_DEVICES\": \"0\",\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "# import ktrain and ktrain.vision modules\n",
    "import ktrain\n",
    "from ktrain import vision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download a PNG version of the **MNIST** dataset from [here](https://s3.amazonaws.com/fast-ai-imageclas/mnist_png.tgz) and set DATADIR to the extracted folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "color_mode detected (grayscale) different than color_mode selected (rgb)\n",
      "Found 60000 images belonging to 10 classes.\n",
      "Found 60000 images belonging to 10 classes.\n",
      "Found 10000 images belonging to 10 classes.\n"
     ]
    }
   ],
   "source": [
    "# Folder holding the extracted MNIST PNG images; it is expected to contain\n",
    "# the 'training' and 'testing' subfolders named below.\n",
    "DATADIR = 'data/mnist_png'\n",
    "\n",
    "# Modest data augmentation: feature-wise normalization plus small random\n",
    "# rotations, zooms and shifts.\n",
    "data_aug = vision.get_data_aug(\n",
    "    featurewise_center=True,\n",
    "    featurewise_std_normalization=True,\n",
    "    rotation_range=15,\n",
    "    zoom_range=0.1,\n",
    "    width_shift_range=0.1,\n",
    "    height_shift_range=0.1,\n",
    ")\n",
    "\n",
    "# The images are grayscale, but we load them as RGB because some models\n",
    "# only support RGB images.\n",
    "train_data, val_data, preproc = vision.images_from_folder(\n",
    "    datadir=DATADIR,\n",
    "    data_aug=data_aug,\n",
    "    train_test_names=['training', 'testing'],\n",
    "    target_size=(32, 32),\n",
    "    color_mode='rgb',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is Multi-Label? False\n",
      "wrn22 model created.\n"
     ]
    }
   ],
   "source": [
    "# Build a pre-canned 22-layer Wide Residual Network ('wrn22') image classifier\n",
    "# for the training and validation data loaded above.\n",
    "model = vision.image_classifier('wrn22', train_data, val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the model and data in a ktrain Learner object; the cells below use it\n",
    "# for lr_find(), lr_plot() and autofit().\n",
    "# NOTE(review): workers=8 with use_multiprocessing=True presumably loads\n",
    "# batches in 8 parallel worker processes - adjust to the machine's CPU count.\n",
    "learner = ktrain.get_learner(model=model, train_data=train_data, val_data=val_data, \n",
    "                             workers=8, use_multiprocessing=True, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "simulating training for different learning rates... this may take a few moments...\n",
      "Epoch 1/5\n",
      "937/937 [==============================] - 65s 69ms/step - loss: 6.9835 - acc: 0.1659\n",
      "Epoch 2/5\n",
      "937/937 [==============================] - 59s 63ms/step - loss: 5.1547 - acc: 0.6753\n",
      "Epoch 3/5\n",
      "937/937 [==============================] - 59s 63ms/step - loss: 0.9958 - acc: 0.9507\n",
      "Epoch 4/5\n",
      "756/937 [=======================>......] - ETA: 11s - loss: 1.4162 - acc: 0.8885\n",
      "\n",
      "done.\n",
      "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n"
     ]
    }
   ],
   "source": [
    "# Find a good learning rate: this simulates training at different learning\n",
    "# rates (it may take a few moments); inspect the result with learner.lr_plot().\n",
    "learner.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEOCAYAAACKDawAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAIABJREFUeJzt3Xl4VOXd//H3d7KSlS0sScCwCAjIZkBBtNAqpa5UrWutSq3r02qX56ltn7Z279Pa5dfFtha0WuuOWgT3raIoGPZV9h1CErYsZJ3798eMmtIASciZM8vndV1zZebMmbm/dyb5zJn7nLmPOecQEZH4F/C7ABERiQwFvohIglDgi4gkCAW+iEiCUOCLiCQIBb6ISIJQ4IuIJAgFvohIglDgi4gkCAW+iEiCSPa7gOa6d+/uioqK/C5DRCRmLFq0qNw5l9eadaMq8IuKiigpKfG7DBGRmGFmW1u7roZ0REQShAJfRCRBKPBFRBKEAl9EJEEo8EVEEoRngW9mg81sabPLITO706v2RETk2Dw7LNM59wEwCsDMkoCdwDNetPXamlKCDgIGATMs/DN0gUDASEkyUpICJAcCpCYbyYEAKckBUgJGclLgo/tTkgIkBcyLMkVEfBWp4/A/BWx0zrX6eNG2uP2RxdQ2BDvs+cwIhX/ASEkOvUl89IaRZKSGf4bWaXb9o3XC15vdl5YSICs1mcy0ZDLTksI/k8lMTSYrLZmumal0yUwhLTmpw/ohItJcpAL/SuBRr5786VvPJOgcQedoCjoc4Jwj6CAYDC1rCDoaGoM0BoM0NDkamoI0Njnqm4I0NoWXBYM0NLoj1glS3+TC6wRpCLqP1w8vq28MUl3X+PFjgh/f92EbdQ1B6puO/6b0cfin0i0zlS4ZqXTLCv/MTP3ovrysNHrmpukNQkRazfPAN7NU4CLgW0e5/ybgJoC+ffu2q42h+TntLS+i6huD1NQ3UlXXSHVdE1V1jdTUN1JZ28j+mnr2VdWzr6aefdWhS+mhWtbuPkRFdT11jS2/WeRlp1HQuROFXToxIC+LAT2yGJiXRf+8TNJT9GYgIh8z55y3DZhdDNzunJtyvHWLi4udplb4T845Djc0UVFVz/6aeiqq6ymrrGP3gVp2HTjMroOH2VpRw479NQTDL6cZH78J5GUxsEfo5ym9s8lOT/G3QyLSYcxskXOuuDXrRmJI5yo8HM5JBGZGRmoyGV2T6dM146jr1TY0saWimg17q9i4t5oNZVVs3FvFe5sqPtrHYQaDe2ZTXNSF007qQvFJXSns0gkz7agWiXeebuGbWSawDejvnDt4vPW1he+NYNCx88BhNuytYtmOAyzaup8l2w5QVdcIQM+cNCYOzOPsQd056+Q8umam+lyxiLRW1GzhO+eqgW5etiHHFwgYfbpm0KdrBpOH9ACgKehYV1pJydb9LNhUwWtrS5m1eAdmcHq/rlwwIp+pw3vRPSvN5+pFpKN4PobfFtrC909T0LFi50FeX7uXuct3sbGsmoDB+AHduOy0Qs4/NZ/UZH0xWyTatGULX4Ev/8E5xwellcxdvpvnlu1iS0UNedlpfOGMk7j69L5001a/SNRQ4EuHCQYd8zaUc//bm/nXujLSkgN8dnQB0yf2Y1DPbL/LE0l4UTOGL7EvEDA+MSiPTwzKY31pJQ/M38LTi3fweMl2po0q4OtTBlHY5ehHDolI9NAWvrTZ/up67pu3ifvf3oxzcN2Ek7h98kA6Z+joHpFIa8sWvvbCSZt1yUzlm1OH8MY3JnHxqHxmvL2ZT/7qXzy7ZCfRtAEhIv9OgS/tlt+5E7/83Ejmfvks+nbN4M7HlzL9b+9TeqjW79JEpAUKfDlhQ/NzmHXrBL53wVDe3VTBlN+8xXPLdvldlogcQYEvHSIpYEyf2I/nv3IW/fMy+fKjS/jfZ1dQ29Dkd2kiEqbAlw7VPy+LJ24ez81n
9+fh97Zx2Z/ns+vAYb/LEhEU+OKBlKQA3zrvFGZ8oZit5TVc+qf5rCut9LsskYSnwBfPnDO0J4/fPJ6moOOyP81n4eZ9fpckktAU+OKpofk5PH3bBLpnp/H5mQt4YcVuv0sSSVgKfPFcYZcMZt0ygeH5Odz2yGL+/u4Wv0sSSUgKfImILpmp/OPGM/jUkB5895+reHD+Fr9LEkk4CnyJmE6pSfzp86cxZWhPvj97lbb0RSJMgS8RlZIU4I/XjOGcU3ryvdmreF5j+iIRo8CXiEtJCvD7q0Yzpm8X7nxsKe9tqvC7JJGEoMAXX3RKTWLmdcX07ZbBlx4sYc3uQ36XJBL3FPjim84ZqTw4fRyZacnc8MD7lFfV+V2SSFxT4IuvCjp3Yub1xeyvqeeOx5bQFNT0yiJe8TTwzayzmT1lZmvNbI2ZjfeyPYlNw/Jz+dHFw3lnQwW/eWWd3+WIxC2vt/D/H/Cic24IMBJY43F7EqMuH9uHy4sL+cMbG3h9banf5YjEJc8C38xygbOBmQDOuXrn3AGv2pPY98OLhzO0dw7feHI5+6rr/S5HJO54uYXfDygDHjCzJWY2w8wyPWxPYlx6ShK/vmIkhw438OM5q/0uRyTueBn4ycAY4E/OudFANXDXkSuZ2U1mVmJmJWVlZR6WI7FgSK8cbp00gKeX7OStdfp7EOlIXgb+DmCHc25B+PZThN4A/o1z7j7nXLFzrjgvL8/DciRW3D55IP3zMvn2MyuoqW/0uxyRuOFZ4Dvn9gDbzWxweNGnAH1Ol+NKT0niZ589lR37D/Prl3XUjkhH8foonS8D/zCz5cAo4Kcetydx4vT+3bj69L7c/85mlu/Qvn6RjuBp4DvnloaHa0Y456Y55/Z72Z7El7s+M4TuWWl8c9YKGpqCfpcjEvP0TVuJWjnpKfzw4uGs2X2Iv87b5Hc5IjFPgS9RberwXkwd1ov/9+p6tlXU+F2OSExT4EvUu/uiYZjBr1/5wO9SRGKaAl+iXq/cdG44sx//XLaL1bs0jbJIeynwJSbccvYAstOS+eVLa/0uRSRmKfAlJuRmpHDb5IG88UEZC3SGLJF2UeBLzLhufBE9c9L4xUsf4JzmzRdpKwW+xIxOqUnc8alBLNq6n1fX7PW7HJGYo8CXmHJ5cSFF3TL43WvrtZUv0kYKfIkpyUkBbvnEAFbsPMjbG8r9LkckpijwJeZ8dkwBPXPSuPeNjX6XIhJTFPgSc9KSk7hxYn/e3VTBkm2ankmktRT4EpOuOr0vuZ1SuPdNbeWLtJYCX2JSVloy100o4pXVpawvrfS7HJGYoMCXmHX9hCLSkgPc/85mv0sRiQkKfIlZXTNTuWRMIbMW76Siqs7vckSingJfYtoXJxZR3xjk4fe2+V2KSNRT4EtMG9gjm8mD8/j7e1uobWjyuxyRqKbAl5h341n9Ka+qZ/ayXX6XIhLVFPgS8yYM6MaQXtnMnLdZ0y2IHIMCX2KemTF9Yj8+KK3k3Y2aOlnkaBT4EhcuGplPt8xU7n9ni9+liEQtTwPfzLaY2QozW2pmJV62JYktPSWJq8b15bW1pezYr5Odi7QkElv4k51zo5xzxRFoSxLYleP6APD4+9t9rkQkOmlIR+JGYZcMJg/uwePvb6ehKeh3OSJRx+vAd8DLZrbIzG7yuC0Rrh7Xl72VdbymM2KJ/AevA3+ic24M8BngdjM7+8gVzOwmMysxs5KysjKPy5F4N2lwHr1z0/nHgq1+lyISdTwNfOfczvDPvcAzwLgW1rnPOVfsnCvOy8vzshxJAMlJAa4c25d568vZVqGdtyLNeRb4ZpZpZtkfXgemACu9ak/kQ1eM7UNSwHhkoebXEWnOyy38nsDbZrYMWAjMdc696GF7IgD0yk3nU0N68GTJduobtfNW5EOeBb5zbpNzbmT4Msw59xOv2hI50tWn96Wiup6XVu3xuxSRqKHDMiUunX1yHoVdOvHIAg3riHxIgS9xKRAwrhrXl3c3
VWjnrUiYAl/i1rTRBQD8c+lOnysRiQ4KfIlbBZ07cXq/rjyzdKemTRZBgS9xbtroAjaVVbNy5yG/SxHxnQJf4tp5w3uTmhTgmSUa1hFR4Etcy81I4ZNDejB72S4aNaGaJDgFvsS9aaPzKa+qY77OhiUJToEvcW/S4B7kpCfzrIZ1JMEp8CXupackcf6I3ry4ag819Y1+lyPiGwW+JISLRxVQU9/EK6tL/S5FxDcKfEkI44q6kp+brmEdSWgKfEkIgYBx8egC3lpfTnlVnd/liPhCgS8JY9qoApqCjrnLd/tdiogvFPiSMAb3yuaU3jn6EpYkLAW+JJTPjs5n6fYDbCmv9rsUkYhT4EtCuWhkAWbwrGbQlASkwJeE0is3nfH9u/HsEs2gKYlHgS8JZ9roArZU1LBsx0G/SxGJKAW+JJypw3uRmhzQMfmScBT4knBy0lM495SePLdsFw2aQVMSiAJfEtK00QVUVNfz9oZyv0sRiRjPA9/MksxsiZnN8botkdb6xKA8OmekaFhHEkoktvDvANZEoB2RVktNDnD+qb15eVUp1XWaQVMSQ6sC38zuMLMcC5lpZovNbEorHlcInA/MONFCRTratNEFHG5o4uXVe/wuRSQiWruFP905dwiYAnQBrgV+3orH/Rb4H0B7xiTqnNa3C/m56cxZprl1JDG0NvAt/PM84O/OuVXNlrX8ALMLgL3OuUXHWe8mMysxs5KysrJWliNy4gIB44KR+by1voyDNQ1+lyPiudYG/iIze5lQ4L9kZtkcf6v9TOAiM9sCPAZ80swePnIl59x9zrli51xxXl5eG0oXOXEXjsinocnx0ioN60j8a23gfxG4CxjrnKsBUoAbjvUA59y3nHOFzrki4Ergdefc50+kWJGONrwgh6JuGTy3fJffpYh4rrWBPx74wDl3wMw+D/wvoO+lS8wzMy4cmc87G8rZW1nrdzkinmpt4P8JqDGzkcDXgY3AQ61txDn3pnPugnbUJ+K5i0cVEHQwe6m28iW+tTbwG11oasGLgT845/4IZHtXlkjkDOyRxcjCXP6pwJc419rArzSzbxE6HHOumQUIjeOLxIXzTu3Nip0H2b6vxu9SRDzT2sC/AqgjdDz+HqAQ+KVnVYlE2Hmn9gbg+RU6Jl/iV6sCPxzy/wByw8fX1zrnWj2GLxLt+nTNYERhrgJf4lprp1a4HFgIfA64HFhgZpd5WZhIpH1meG+W7dCwjsSv1g7pfIfQMfjXOee+AIwDvutdWSKRd354WOeFldrKl/jU2sAPOOf2Nrtd0YbHisSEvt0yGF6Qw/Mr9K1biU+tDe0XzewlM7vezK4H5gLPe1eWiD/OO7U3S7cfYOeBw36XItLhWrvT9r+B+4AR4ct9zrlvelmYiB/OGx4e1tHOW4lDya1d0Tk3C5jlYS0ivivqnsnQ3jnMXbGbG8/q73c5Ih3qmFv4ZlZpZodauFSa2aFIFSkSSeeP6M2SbQfYpWEdiTPHDHznXLZzLqeFS7ZzLidSRYpEkr6EJfFKR9qIHKFf90yGF+Tw3HIFvsQXBb5ICy4ckc+y7QfYVqEvYUn8UOCLtOD8EaFhHZ0YReKJAl+kBYVdMjjtpC48t0yBL/FDgS9yFBeO6M3aPZWsL630uxSRDqHAFzmK80b0JmBo563EDQW+yFH0yE7njP7dmLNsF6ETvonENgW+yDFcODKfTeXVrNql7xlK7FPgixzD1GG9SA6Ydt5KXFDgixxDl8xUzjq5O3OW7yYY1LCOxDbPAt/M0s1soZktM7NVZvYDr9oS8dJFo/LZeeAwS7bv97sUkRPi5RZ+HfBJ59xIYBQw1czO8LA9EU+cO7QXackBZi/VsI7ENs8C34VUhW+mhC/6TCwxJystmU+d0oO5K3bT2BT0uxyRdvN0DN/MksxsKbAXeMU5t6CFdW4ysxIzKykrK/OyHJF2u2hkPuVV9by3aZ/fpYi0m6eB75xrcs6NAgqBcWY2vIV17nPOFTvn
ivPy8rwsR6TdJg3uQVZaMrOX7fS7FJF2i8hROs65A8AbwNRItCfS0dJTkpgyrCcvrNxDbUOT3+WItIuXR+nkmVnn8PVOwLnAWq/aE/HapWMKqaxt5KVVe/wuRaRdvNzC7w28YWbLgfcJjeHP8bA9EU+N79+Ngs6deGrRDr9LEWmXVp/EvK2cc8uB0V49v0ikBQLGpWMK+P0bG9h98DC9czv5XZJIm+ibtiJtcMmYQpyDZ5Zo563EHgW+SBsUdc9kbFEXZi3aoRk0JeYo8EXa6NIxhWwsq2bp9gN+lyLSJgp8kTY6b0Rv0pIDzFqsnbcSWxT4Im2Uk57C1OG9mL10l47Jl5iiwBdph0vHFHKotpHX1uz1uxSRVlPgi7TDmQO70zs3nSdKtvtdikirKfBF2iEpYHyuuA9vrS9j+74av8sRaRUFvkg7XTWuDwEzHnp3i9+liLSKAl+knXrnduK8U3vz2MLtVNY2+F2OyHEp8EVOwI0T+1FZ18iTJTpEU6KfAl/kBIzs05kxfTvz4LtbaNJJziXKKfBFTtD0if3YWlHDa2tK/S5F5JgU+CInaOqwXhR07sSMeZv9LkVi0Ix5m7jpoZKInC9ZgS9ygpKTAkyf2I+FW/axZNt+v8uRGPPcsl3srawjOcn7OFbgi3SAK8b2ISc9mb/O2+R3KRJD9hysZdmOg5w7tGdE2lPgi3SArLRkrjnjJF5cuYetFdV+lyMx4sdzVwPw6WG9ItKeAl+kg9wwoYjkQEBb+dIquw4cZs7y3dx8dn8G9siKSJsKfJEO0iMnnUvGFPBkyQ7Kq+r8Lkei3Lz1ZUDoLGqRosAX6UBfOrs/9U1BHpy/xe9SJMq9ta6cXjnpDOoZma17UOCLdKgBeVlMGdqTh97dqukW5KiCQcfbG8o56+TumFnE2vUs8M2sj5m9YWarzWyVmd3hVVsi0eT2yQM5eLiBv+q4fDmKrftqOHi4geKiLhFt18st/Ebg6865ocAZwO1mNtTD9kSiwojCzpx3ai9mzNtEWaXG8uU/rdp1EIBh+bkRbdezwHfO7XbOLQ5frwTWAAVetScSTb4xZTB1jUH+8Pp6v0uRKPK719Zz6t0v8eM5a0hLDnByBMfvIUJj+GZWBIwGFkSiPRG/9c/L4vLiPjyycBvbKnSCFIHlOw7wm1fXUVnbyJ5DteR37kRaclJEa/A88M0sC5gF3OmcO9TC/TeZWYmZlZSVlXldjkjE3HnOySQFjB/PXY1zmkkz0T27ZBepSQGeuW0CAF86q3/Ea/A08M0shVDY/8M593RL6zjn7nPOFTvnivPy8rwsRySieuak81+TB/Ly6lJeWLnH73LEZ/M3llNc1IXRfbuw5efnc/XpfSNeg5dH6RgwE1jjnPu1V+2IRLNbPjGAU3rn8KM5q6mua/S7HPFJRVUda/dUMmFAd1/r8HIL/0zgWuCTZrY0fDnPw/ZEok5yUoAfXTyM3Qdr+eMbG/wuR3yyYPM+AM7o383XOrw8Sudt55w550Y450aFL8971Z5ItCou6srFo/K5/53N7DlY63c54oGmoOPV1aXsPni4xfvnbywnMzWJEYWRPQzzSPqmrUgEfGPKYIJBuOflD/wuRTqYc47/emQxNz5Uwvifvf7RHDnNzd9Ywdh+XUmJwJz3x6LAF4mAPl0zmD6xH08t2sHC8Md7iQ+zl+3ihZV7mDYqH4AvPVTC3sqPP8kdrm9iU1k1Y/pG9lu1LVHgi0TIVz41kILOnbjr6eUc0jw7caGmvpH/fnI5J/fI4jdXjOKBG8ZS2xBk3E9e46F3t3DNjPeY+XZouuyi7pn+FosCXyRiMlKTuedzI9lWUcOdjy0lGNSx+bFs4eZ9nP2LN6hvCvLVcwdhZkwe3IPvXxiaQeZ7/1zFOxsquOfldQCcdpL/W/jJfhcgkkjGD+jG/55/Cnc/t5qH3t3C9Wf287skaaefzF1NeVU908/sx2eGf3zGqhvO7MdZJ3fnnpfWMaR3Nukp
STQFHQWdO/lYbYgCXyTCrptQxL/WlfGzF9YyYWB3BvXM9rskaaM5y3exbMdB7r5waItv2gN7ZPPna0/zobJj05COSISZGb+4bCRZacnc+dhS6hqb/C5J2mDPwVq+++xKRhTm8vkzTvK7nDZR4Iv4IC87jZ9fOoLVuw/x61fW+V2OHMXS7Qc4/aevcvfsVTjnmL+hnDP/73Wq65v41edGkuzzYZZtpSEdEZ+cO7QnV43ry31vbWLK0F5RsVNPPlbX2MSXH11M6aE6/jZ/C+9sKGf93ioAnrxlPCfH4FBcbL09icSZb583hPzcTnz9iaU6VDPKzJi3me37DvPADWO5bdIA9tc0cO7Qnjx7+5lRcUx9e1g0TdtaXFzsSkpK/C5DJKIWbKrgmhkLOKN/Nx64Yazv38ZMdIfrm7jjsSW8vLqUcf268vhNZ0T0vLNtZWaLnHPFrVlXf1kiPju9fzd+esmpvL2hnO88s0Jz5/vs/nc28/LqUr44sR8zryuO6rBvK43hi0SBy4v7sH1fDb9/fQNF3TO5bdJAv0tKSI1NQR5+bytnDuzGdy+Iv1NwawtfJEp87dxBXDgyn1++9AHPLdvldzkJaeHmfew+WMvnT4+twy1bS1v4IlHCzPjFpSMoPVjLVx9fSmZaEp8c0tPvshLK2xvKSQ4YZw2Kz7PvaQtfJIp0Sk1i5vXFnNI7h1sfXsyrq0v9LilhNDYFeW75Lsb160pWWnxuCyvwRaJMdnoKD04fx0ndMrj54UUs2FThd0kJYe6K3Wzfd5hrY+zbs22hwBeJQl0zU3ni5vEUdO7ELQ8vYsPeSr9LimuHahv46fNrGNgjiynDeh3/ATFKgS8SpTpnpPLQ9HEkBYypv53H39/bqkM2PRAMOr70YAl7K+v4ybThJAXi5zDMIynwRaJYUfdMnrntTAq7dOK7z67k28+spEnz6HeYAzX13PhQCQs27+PuC4dxus8nGfeaAl8kyvXpmsHrX5/EbZMG8OjCbdz68CJqGzTD5omqqmtk2h/fYd76Mn5w0TC+MD5+x+4/FJ+7okXiTCBg/M/UIeRlp/HDOav5wsyF/PGaMeRlp/ldWsxZtv0A72/Zx+Pvb2dLRQ0PTR/H2XF6GOaRPAt8M7sfuADY65wb7lU7IonkhjP70T0rja89sZRP3vMmX5syiGtOP4nUZH1Yb413NpRzwwPvU98UBOD6CUUJE/bg4eRpZnY2UAU81NrA1+RpIq2zvrSS789exfyNoUM2f3bJqVw5tk9czfvS0Q7XN3HB7+exr7qeWbdOIDs9JS4+IUXF5GnOubeAfV49v0giO7lnNo986QweuH4s/fMy+dbTK/jKY0tpCG+5ysfqG4PcNWs5p3zvRTaWVfPNqUPon5cVF2HfVhrDF4lhk4f0YPyAbtz7xgZ+9/oGSrbs47dXjIr7o01aq6yyjiv+8i6byqsZ1DOLr08ZzKfj+Dj74/F94M/MbjKzEjMrKSsr87sckZiTnpLE16YM5rdXjKKqtpFrZizgj29sIJjgh29uKa/ms/e+w66Dh/nm1CHM+fJZCR324PEJUMysCJijMXyRyDhU28Dt/1jMvPXlnHNKD359xShy0lP8LstzZZV1LNhcwXubKiivrGf3oVqWbT+AGdx3bTHnDo3fSejaMoavIR2ROJKTnsLfbhjHA+9s5ucvrOWSe+fzw4uGMWFgd79L88zKnQe5duYC9tc0kJoUoCEYxDnITk/md1eNZvLgHn6XGDW8PCzzUWAS0N3MdgDfd87N9Ko9EQlJChg3ntWfIb1y+Oas5Vw9YwGfP6Mv37tgWFwdvnmotoGvPb6MV9eUEjD41edGctGofJIDRlPQ0Rh0pKck+V1mVNE5bUXiWG1DEz+Zu4a/v7eVbpmpXD62D9dPKKJHdlpMH8LZFHRc9uf5LNl2gHH9unLPZSPp2y3D77J80ZYhHQW+SAJ444O9/PWtTby7qSI03JGWzLTRBVwypoDRfbv4XV6r7a2s
5a5ZK3h97V4A/mvyQL7x6cE+V+UvBb6ItGjD3kpeX7uXhZv38+qaUszguvFF3HnOyXTOSPW7vGP617oyvvLoEg4ebuCM/l25dEwhl51WGNOfVDqCAl9Ejqu8qo57XvqAJ0q2k5WWzCVjCrm8uA+n9M6OeIjuPHCYh97dwvLtB8lITeLcoT0Z2COL/M6deOODvWwuq+bv722le1Yaf7n2NIYX5Ea0vmimwBeRVlu96xB/+tdGXlq5h/qmIKnJATp3SmFwr2x65aRz3YSiDgtY59y/vZkcrGngN6+u49GF26hrDNIpJYnDR5kJdGSfzsy8rpjuWYn3DdljUeCLSJvtr65nzordLN12gKq6Bkq27Keiup7UpADXjj+Jq0/vS+/cdDJSW39wX31jkNfWlPLo+9t5a13oi5WDemZR2CWD5TsOUl5VB8BFI/O545yTGZCXRWNTkLV7KnltzV4yUpN4d1MF/zN1MEN65XjS71inwBeRDlFeVccvXlzLU4t2EHQQMOiamcb5p/bi/BH5jC0K7fA1M+obg6QkGVsranht7V5eW1P60eRuuZ1SKOzSieH5uSzdfoAPSkOnbLx+QhGfGJTH5CE6Vr69FPgi0qHWlVby9vpyDtTU88jCbVRU1+Mc5GWncaAmdL0x6MhKS6aqrhEIvTmc0juH6yYUMW1UwUffAQgGHf9aV8awghx6ZKf72a24oMAXEU/tq67n2SU7WbR1P3nZaWSmJVFd18TGsiqG5udw5di+9Oue6XeZCUFTK4iIp7pmpjJ9Yj+mT+zndynSBvHzPWsRETkmBb6ISIJQ4IuIJAgFvohIglDgi4gkCAW+iEiCUOCLiCQIBb6ISIKIqm/amlkZsDV8Mxc42Mrr3YHydjbb/Pnauk5Ly49cFo/9aH47mvpxrDpbc7ulPnnZl7b0o6Vlx/t7ilQ/jrVOov6PRLIfJznn8lr1SOdcVF6A+1p7HSjpiHbauk5Ly49cFo/9OKLmqOlHW+o+zuvQfJlnfWlLP9rztxWpfnTk31a8/I/43Y+jXaJ5SOe5Nl7viHbauk5Ly49cFo/9aH47mvrR0n1tud1Sn07E8Z6nLf1oadnx/p4i1Y9jrZOo/yN+96NFUTWk015mVuJaOXlQNFM/ok9Gs99uAAAIIElEQVS89EX9iC5+9SOat/Db4j6/C+gg6kf0iZe+qB/RxZd+xMUWvoiIHF+8bOGLiMhxKPBFRBKEAl9EJEHEfeCb2Vlm9mczm2Fm8/2up73MLGBmPzGz35vZdX7X015mNsnM5oVfk0l+13MizCzTzErM7AK/a2kvMzsl/Fo8ZWa3+l3PiTCzaWb2VzN73Mym+F1Pe5lZfzObaWZPdfRzR3Xgm9n9ZrbXzFYesXyqmX1gZhvM7K5jPYdzbp5z7hZgDvCgl/UeTUf0A7gYKAQagB1e1XosHdQPB1QB6cR2PwC+CTzhTZXH10H/H2vC/x+XA2d6We+xdFBfnnXOfQm4BbjCy3qPpoP6sck590VPCmzvN7YicQHOBsYAK5stSwI2Av2BVGAZMBQ4lVCoN7/0aPa4J4DsWO0HcBdwc/ixT8VwPwLhx/UE/hHD/TgXuBK4HrggVvsRfsxFwAvA1X70oyP7En7cr4AxcdCPDv8/j+qTmDvn3jKzoiMWjwM2OOc2AZjZY8DFzrmfAS1+tDazvsBB51ylh+UeVUf0w8x2APXhm03eVXt0HfV6hO0H0ryo83g66PWYBGQS+sc9bGbPO+eCXtZ9pI56PZxzs4HZZjYXeMS7io+ug14TA34OvOCcW+xtxS3r4P+RDhfVgX8UBcD2Zrd3AKcf5zFfBB7wrKL2aWs/ngZ+b2ZnAW95WVgbtakfZnYJ8GmgM/AHb0trkzb1wzn3HQAzux4oj3TYH0NbX49JwCWE3nyf97Sytmvr/8iXgXOAXDMb6Jz7s5fFtUFbX5NuwE+A0Wb2rfAbQ4eIxcBvM+fc9/2u4UQ552oIvXHFNOfc04TevOKCc+5v
ftdwIpxzbwJv+lxGh3DO/Q74nd91nCjnXAWh/RAdLqp32h7FTqBPs9uF4WWxRv2ILupH9ImXvkRNP2Ix8N8HTjazfmaWSmjH2Wyfa2oP9SO6qB/RJ176Ej398GuvfCv3eD8K7ObjQxG/GF5+HrCO0J7v7/hdp/qhfqgf6kss9EOTp4mIJIhYHNIREZF2UOCLiCQIBb6ISIJQ4IuIJAgFvohIglDgi4gkCAW+tJuZVUWgjYtaOVVxR7Y5ycwmtONxo81sZvj69WYWFXMFmVnRkdP1trBOnpm9GKmaxB8KfPGdmSUd7T7n3Gzn3M89aPNY80hNAtoc+MC3idG5XJxzZcBuM/NtTnzxngJfOoSZ/beZvW9my83sB82WP2tmi8xslZnd1Gx5lZn9ysyWAePNbIuZ/cDMFpvZCjMbEl7voy1lM/ubmf3OzOab2SYzuyy8PGBm95rZWjN7xcye//C+I2p808x+a2YlwB1mdqGZLTCzJWb2qpn1DE9tewvwVTNbaqEzpuWZ2axw/95vKRTNLBsY4Zxb1sJ9RWb2evh381p4um7MbICZvRfu749b+sRkobNqzTWzZWa20syuCC8fG/49LDOzhWaWHW5nXvh3uLilTylmlmRmv2z2Wt3c7O5ngWtafIElPvj9VWRdYvcCVIV/TgHuA4zQRsQc4OzwfV3DPzsBK4Fu4dsOuLzZc20Bvhy+fhswI3z9euAP4et/A54MtzGU0BzjAJcRmto3APQiNNf+ZS3U+yZwb7PbXeCjb5vfCPwqfP1u4BvN1nsEmBi+3hdY08JzTwZmNbvdvO7ngOvC16cDz4avzwGuCl+/5cPf5xHPeynw12a3cwmdRGMTMDa8LIfQzLcZQHp42clASfh6EeETcgA3Af8bvp4GlAD9wrcLgBV+/13p4t0lIaZHFs9NCV+WhG9nEQqct4CvmNlnw8v7hJdXEDqJy6wjnufDaZMXEZqjvSXPutDc86vNrGd42UTgyfDyPWb2xjFqfbzZ9ULgcTPrTShENx/lMecAQ0Pn1wAgx8yynHPNt8h7A2VHefz4Zv35O/CLZsunha8/AtzTwmNXAL8ys/8D5jjn5pnZqcBu59z7AM65QxD6NAD8wcxGEfr9Dmrh+aYAI5p9Asol9JpsBvYC+Ufpg8QBBb50BAN+5pz7y78tDJ1c4xxgvHOuxszeJHQuW4Ba59yRZ+6qC/9s4uh/m3XNrttR1jmW6mbXfw/82jk3O1zr3Ud5TAA4wzlXe4znPczHfeswzrl1ZjaG0ORbPzaz14BnjrL6V4FSYCShmluq1wh9knqphfvSCfVD4pTG8KUjvARMN7MsADMrMLMehLYe94fDfghwhkftvwNcGh7L70lop2tr5PLxvOTXNVteCWQ3u/0yobMpARDegj7SGmDgUdqZT2hKXAiNkc8LX3+P0JANze7/N2aWD9Q45x4GfknofKkfAL3NbGx4nezwTuhcQlv+QeBaQudSPdJLwK1mlhJ+7KDwJwMIfSI45tE8EtsU+HLCnHMvExqSeNfMVgBPEQrMF4FkM1tD6Fyj73lUwixCU9GuBh4GFgMHW/G4u4EnzWwRUN5s+XPAZz/caQt8BSgO7+RcTQtnI3LOrSV0ar3sI+8j9GZxg5ktJxTEd4SX3wl8Lbx84FFqPhVYaGZLge8DP3bO1QNXEDrl5TLgFUJb5/cC14WXDeHfP818aAah39Pi8KGaf+HjT1OTgbktPEbihKZHlrjw4Zi6hc4HuhA40zm3J8I1fBWodM7NaOX6GcBh55wzsysJ7cC92NMij13PW4ROrr3frxrEWxrDl3gxx8w6E9r5+qNIh33Yn4DPtWH90wjtZDXgAKEjeHxhZnmE9mco7OOYtvBFRBKExvBFRBKEAl9EJEEo8EVEEoQCX0QkQSjwRUQShAJfRCRB/H+GHz2lPY6qbwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the loss recorded by lr_find() to identify the maximal learning rate\n",
    "# at which the loss is still falling.\n",
    "learner.lr_plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "begin training using triangular learning rate policy with max lr of 0.002...\n",
      "Epoch 1/1\n",
      "936/937 [============================>.] - ETA: 0s - loss: 1.0951 - acc: 0.9207\n",
      "937/937 [==============================] - 67s 71ms/step - loss: 1.0942 - acc: 0.9207 - val_loss: 0.2455 - val_acc: 0.9932\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f56b02348d0>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the WRN-22 model for a single epoch with a maximum learning rate of\n",
    "# 2e-3 (autofit reports using a triangular learning rate policy).\n",
    "learner.autofit(2e-3, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the trained model with its preprocessor (preproc) into a Predictor\n",
    "# object that we can use to classify (potentially unlabeled) images.\n",
    "predictor = ktrain.get_predictor(learner.model, preproc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's see the class labels; each label's position in the returned list is\n",
    "# its class index.\n",
    "predictor.get_classes()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['7']"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's try to predict an image depicting a 7.\n",
    "# The path is built from DATADIR (set in the data-loading cell) so the\n",
    "# notebook is not tied to one user's home directory.\n",
    "predictor.predict_filename(os.path.join(DATADIR, 'testing', '7', '7021.png'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[9.99729097e-01, 1.02616505e-05, 4.12802947e-05, 1.04568608e-05,\n",
       "        5.07383811e-06, 3.03435208e-05, 1.10756089e-04, 2.12177038e-05,\n",
       "        1.56492970e-05, 2.57711799e-05]], dtype=float32)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's try predicting an image showing a 0 and return probabilities for all\n",
    "# classes (return_proba=True) instead of just the predicted label.\n",
    "# The path is built from DATADIR (set in the data-loading cell) so the\n",
    "# notebook is not tied to one user's home directory.\n",
    "predictor.predict_filename(os.path.join(DATADIR, 'testing', '0', '101.png'), return_proba=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 1010 images belonging to 1 classes.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[('3/1020.png', '3'),\n",
       " ('3/1028.png', '3'),\n",
       " ('3/1042.png', '3'),\n",
       " ('3/1062.png', '3'),\n",
       " ('3/1066.png', '3'),\n",
       " ('3/1067.png', '3'),\n",
       " ('3/1069.png', '3'),\n",
       " ('3/1072.png', '3'),\n",
       " ('3/1092.png', '3'),\n",
       " ('3/1095.png', '3')]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's predict all images showing a 3 in our validation set and display the\n",
    "# first 10 (filename, predicted label) pairs.\n",
    "# The folder is built from DATADIR (set in the data-loading cell) so the\n",
    "# notebook is not tied to one user's home directory; the final '' component\n",
    "# makes the path end with a path separator.\n",
    "predictor.predict_folder(os.path.join(DATADIR, 'testing', '3', ''))[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's save the predictor for possible later deployment in an application.\n",
    "# NOTE(review): /tmp is typically not persistent - choose a durable location\n",
    "# if the saved predictor needs to be kept.\n",
    "predictor.save('/tmp/mypredictor')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the predictor from the path it was saved to above, replacing the\n",
    "# in-memory predictor object.\n",
    "predictor = ktrain.load_predictor('/tmp/mypredictor')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['7']"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Let's use the reloaded predictor on the same image of a 7 to verify it\n",
    "# still works correctly (it should again predict '7').\n",
    "# The path is built from DATADIR (set in the data-loading cell) so the\n",
    "# notebook is not tied to one user's home directory.\n",
    "predictor.predict_filename(os.path.join(DATADIR, 'testing', '7', '7021.png'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
